package org.sql.flow.service;

import com.google.common.collect.Lists;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.annotation.Resource;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import org.sql.flow.constant.DltConstant;
import org.sql.flow.data.data_lineage.FieldDetail;
import org.sql.flow.data.data_lineage.FieldEdgeInfo;
import org.sql.flow.data.data_lineage.FieldLineageCanvasInfo;
import org.sql.flow.data.data_lineage.FieldNodeInfo;
import org.sql.flow.data.data_lineage.FieldNodeInfo.TableFieldDetail;
import org.sql.flow.lineage.data.ColumnInfo;
import org.sql.flow.lineage.data.DatabaseConfig;
import org.sql.flow.lineage.data.FieldInfo;
import org.sql.flow.lineage.data.FieldLineageInfo;
import org.sql.flow.lineage.enums.DataType;
import org.sql.flow.lineage.parser.DataLineageParser;
import org.sql.flow.lineage.parser.database.DatabaseFactory;
import org.sql.flow.lineage.parser.database.DatabaseService;

/**
 * Builds field-level lineage canvas data (nodes and edges) from SQL statements.
 *
 * <p>A single statement is parsed by {@link #processFieldLineageParser}. The results of several
 * statements are combined by {@link #mergeFieldCanvasInfo}, which also tries to enrich each
 * table node with its real column list, read through a {@link DatabaseService}.
 *
 * @author dashuiguai
 * @create 2023-07-12 9:52
 */
@Service
public class FieldLineageService {

    private static final Logger logger = LoggerFactory.getLogger(FieldLineageService.class);
    public static final String DATABASE_CONFIG_ERROR = "database config error";

    @Resource
    private DataLineageParser dataLineageParser;

    @Resource
    private DatabaseFactory databaseFactory;

    /**
     * Parses one SQL statement and converts its field lineage into canvas nodes and edges.
     *
     * <p>Each target field and each of its source fields produces one node containing just that
     * field (type {@code DataType.UNKNOWN}). Each source-to-target pair produces one edge. Nodes
     * and edges are not de-duplicated here; {@link #mergeFieldCanvasInfo} does that.
     *
     * @param sqlType        SQL dialect, passed through to the parser
     * @param databaseConfig connection settings; its database name is the fallback for tables
     *                       without an explicit database (may be {@code null})
     * @param sql            the SQL text to analyse
     * @return the canvas info, or {@code null} when the parser yields no lineage
     */
    public FieldLineageCanvasInfo processFieldLineageParser(String sqlType, DatabaseConfig databaseConfig, String sql) {
        List<FieldLineageInfo> fieldLineageInfoList = dataLineageParser.processFieldLineageParse(sqlType, sql, databaseConfig);

        if (CollectionUtils.isEmpty(fieldLineageInfoList)) {
            return null;
        }

        // Default database name: the configured one, otherwise the global default constant.
        String configuredDatabaseName = databaseConfig == null ? null : databaseConfig.getDatabaseName();
        String defaultDatabaseName = StringUtils.isBlank(configuredDatabaseName)
                ? DltConstant.DEFAULT_DATABASE_NAME
                : configuredDatabaseName;

        List<FieldNodeInfo> fieldNodeInfos = Lists.newArrayList();
        List<FieldEdgeInfo> fieldEdgeInfos = Lists.newArrayList();

        for (FieldLineageInfo fieldLineageInfo : fieldLineageInfoList) {
            FieldInfo downFieldInfo = fieldLineageInfo.getTargetField();
            String downDatabaseName = resolveDatabaseName(downFieldInfo, defaultDatabaseName);
            FieldDetail target = toFieldDetail(downFieldInfo, downDatabaseName);
            fieldNodeInfos.add(toFieldNodeInfo(downFieldInfo, downDatabaseName));

            for (FieldInfo upFieldInfo : fieldLineageInfo.getSourceFields()) {
                String upDatabaseName = resolveDatabaseName(upFieldInfo, defaultDatabaseName);
                fieldNodeInfos.add(toFieldNodeInfo(upFieldInfo, upDatabaseName));
                fieldEdgeInfos.add(
                        FieldEdgeInfo.builder()
                                .source(toFieldDetail(upFieldInfo, upDatabaseName))
                                .target(target)
                                .build()
                );
            }
        }
        return FieldLineageCanvasInfo.builder()
                .fieldEdgeInfoList(fieldEdgeInfos)
                .fieldNodeInfoList(fieldNodeInfos)
                .build();
    }

    /**
     * Merges several canvas infos into one, de-duplicating edges and table nodes.
     *
     * <p>Edges are de-duplicated by {@code FieldEdgeInfo::generateIdentity}. Nodes are grouped by
     * {@code FieldNodeInfo::generateIdentity}, and groups are emitted in first-seen order. For
     * each group, the table's full column list is loaded from the database when possible. If
     * that fails or returns nothing, the union of fields seen in the lineage is used instead,
     * de-duplicated by field name.
     *
     * @param fieldLineageCanvasInfos canvas infos to merge; {@code null} entries (as returned by
     *                                {@link #processFieldLineageParser} for empty input) are skipped
     * @param databaseConfig          connection settings used for column lookup
     * @param sqlType                 SQL dialect used to pick the {@link DatabaseService}
     * @return the merged canvas info (never {@code null})
     */
    public FieldLineageCanvasInfo mergeFieldCanvasInfo(
            List<FieldLineageCanvasInfo> fieldLineageCanvasInfos,
            DatabaseConfig databaseConfig,
            String sqlType
    ) {
        List<FieldNodeInfo> fieldNodeInfos = Lists.newArrayList();
        List<FieldEdgeInfo> fieldEdgeInfos = Lists.newArrayList();

        if (fieldLineageCanvasInfos != null) {
            for (FieldLineageCanvasInfo fieldLineageCanvasInfo : fieldLineageCanvasInfos) {
                // processFieldLineageParser returns null when nothing was parsed; tolerate that here.
                if (fieldLineageCanvasInfo == null) {
                    continue;
                }
                if (fieldLineageCanvasInfo.getFieldNodeInfoList() != null) {
                    fieldNodeInfos.addAll(fieldLineageCanvasInfo.getFieldNodeInfoList());
                }
                if (fieldLineageCanvasInfo.getFieldEdgeInfoList() != null) {
                    fieldEdgeInfos.addAll(fieldLineageCanvasInfo.getFieldEdgeInfoList());
                }
            }
        }

        // De-duplicate edges.
        fieldEdgeInfos = fieldEdgeInfos.stream()
                .filter(DataLineageUtil.distinctByKey(FieldEdgeInfo::generateIdentity))
                .collect(Collectors.toList());

        // LinkedHashMap keeps node groups in first-seen order so the output is deterministic.
        Map<String, List<FieldNodeInfo>> nodeMap = fieldNodeInfos.stream()
                .collect(Collectors.groupingBy(
                        FieldNodeInfo::generateIdentity, LinkedHashMap::new, Collectors.toList()));

        List<FieldNodeInfo> resultFieldNodeInfo = Lists.newArrayList();

        for (Map.Entry<String, List<FieldNodeInfo>> entry : nodeMap.entrySet()) {
            List<FieldNodeInfo> elements = entry.getValue();
            FieldNodeInfo first = elements.get(0);

            List<TableFieldDetail> tableFieldDetails = getAllFields(
                    databaseConfig,
                    sqlType,
                    first.getDatabaseName(),
                    first.getTableName()
            );

            if (CollectionUtils.isEmpty(tableFieldDetails)) {
                // No metadata available: use the fields seen in the lineage, de-duplicated by name.
                tableFieldDetails = elements.stream()
                        .map(FieldNodeInfo::getTableFieldDetailList)
                        .filter(Objects::nonNull)
                        .flatMap(Collection::stream)
                        .filter(DataLineageUtil.distinctByKey(TableFieldDetail::getFieldName))
                        .collect(Collectors.toList());
            }

            resultFieldNodeInfo.add(
                    FieldNodeInfo.builder()
                            .databaseName(first.getDatabaseName())
                            .tableName(first.getTableName())
                            .tableFieldDetailList(tableFieldDetails)
                            .build()
            );
        }

        return FieldLineageCanvasInfo.builder()
                .fieldNodeInfoList(resultFieldNodeInfo)
                .fieldEdgeInfoList(fieldEdgeInfos)
                .build();
    }

    /**
     * Loads the column list of a table through the {@link DatabaseService} for {@code sqlType}.
     *
     * <p>Best effort: returns an empty list in these cases, so callers can fall back to
     * lineage-derived fields:
     * <ul>
     *   <li>the connection is not configured (config missing, or host/password/username empty);</li>
     *   <li>no service is available for the dialect;</li>
     *   <li>the lookup fails (the failure is logged with its cause).</li>
     * </ul>
     *
     * @return the table's columns with converted types, or an empty list
     */
    private List<TableFieldDetail> getAllFields(
            DatabaseConfig databaseConfig,
            String sqlType,
            String databaseName,
            String tableName
    ) {
        // Connection details are not configured, so no metadata lookup is possible.
        if (databaseConfig == null
                || StringUtils.isAnyEmpty(databaseConfig.getHost(), databaseConfig.getPassword(), databaseConfig.getUsername())) {
            return Lists.newArrayList();
        }
        try {
            DatabaseService databaseService = databaseFactory.createDatabaseService(sqlType);
            if (databaseService == null) {
                logger.warn("{}: no database service available for sqlType={}", DATABASE_CONFIG_ERROR, sqlType);
                return Lists.newArrayList();
            }
            List<ColumnInfo> columnInfoList = databaseService.getAllFields(databaseConfig, databaseName, tableName);
            if (columnInfoList == null) {
                return Lists.newArrayList();
            }

            return columnInfoList.stream()
                    .map(element -> TableFieldDetail.builder()
                            .fieldType(databaseService.dataTypeConvert(element.getColumnType()))
                            .fieldName(element.getColumnName())
                            .build()
                    ).collect(Collectors.toList());
        } catch (Exception e) {
            // Keep best-effort semantics, but record the context and cause (never credentials).
            logger.error("{}: failed to load columns of {}.{} (sqlType={})",
                    DATABASE_CONFIG_ERROR, databaseName, tableName, sqlType, e);
        }
        return Lists.newArrayList();
    }

    /** Returns the field's own database name, or {@code defaultDatabaseName} when it is blank. */
    private static String resolveDatabaseName(FieldInfo fieldInfo, String defaultDatabaseName) {
        String databaseName = fieldInfo.getTableInfo().getDatabaseName();
        return StringUtils.isBlank(databaseName) ? defaultDatabaseName : databaseName;
    }

    /** Builds the edge endpoint describing {@code fieldInfo} in database {@code databaseName}. */
    private static FieldDetail toFieldDetail(FieldInfo fieldInfo, String databaseName) {
        return FieldDetail.builder()
                .databaseName(databaseName)
                .tableName(fieldInfo.getTableInfo().getTableName())
                .fieldName(fieldInfo.getFieldName())
                .build();
    }

    /** Builds a single-field table node for {@code fieldInfo}, with type {@code DataType.UNKNOWN}. */
    private static FieldNodeInfo toFieldNodeInfo(FieldInfo fieldInfo, String databaseName) {
        return FieldNodeInfo.builder()
                .databaseName(databaseName)
                .tableName(fieldInfo.getTableInfo().getTableName())
                .tableFieldDetailList(Lists.newArrayList(
                        new TableFieldDetail(fieldInfo.getFieldName(), DataType.UNKNOWN.name())
                ))
                .build();
    }
}